[tx] General implementation of trainable Hyper Connections#1008
tanmaysachan wants to merge 8 commits into NovaSky-AI:main
Conversation
Code Review
This pull request introduces a general implementation of Hyper Connections as an extension to the transformer layers. The changes are mainly in `tx/layers/connectors.py`, where the `Connector` module is defined, and in `tx/models/deepseekv3.py` to integrate it into the decoder layers.
My review found a couple of issues:
- An unused `trainable` parameter in the `Connector` class, which should be removed for clarity.
- A bug in `DeepseekV3Model` when handling intermediate hidden states for `expansion_rate > 1`, where `squeeze()` is used incorrectly.
Overall, the implementation of the Hyper Connections logic seems to follow the intended pattern of pre/post processing around existing attention and MLP blocks. The changes are well-contained. Addressing the mentioned points will improve the robustness and clarity of the implementation.
```diff
 for layer_idx, layer in enumerate(self.layers):
     if output_hidden_states:
-        all_hidden_states.append(hidden_states)
+        all_hidden_states.append(hidden_states.squeeze())
```
`hidden_states.squeeze()` is used here to process intermediate hidden states. This will only work correctly if `expansion_rate` is 1. For `expansion_rate > 1`, `squeeze()` will have no effect because the expansion dimension has size n > 1. This will result in appending a tensor with an incorrect shape `(..., n, C)` to `all_hidden_states`, which is inconsistent with the other states and likely to cause issues downstream.
A more robust approach is to aggregate across the expansion dimension, for example by taking the mean.
Suggested change:
```diff
-all_hidden_states.append(hidden_states.squeeze())
+all_hidden_states.append(hidden_states.mean(axis=-2))
```
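To make the shape issue concrete, here is a small check; the batch/sequence/expansion/hidden sizes are hypothetical, chosen only for illustration:

```python
import jax.numpy as jnp

# Hypothetical shapes: batch=2, seq=3, expansion rate n=4, hidden dim C=8.
h = jnp.zeros((2, 3, 4, 8))

print(h.squeeze().shape)      # (2, 3, 4, 8) -- no size-1 axes, so squeeze() is a no-op
print(h.mean(axis=-2).shape)  # (2, 3, 8)    -- aggregates across the n expanded streams
```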
```python
hidden_dim: int,
expansion_rate: int,
*,
trainable: bool = False,
```
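For context, a minimal sketch of what a hyper-connection connector with this signature could look like. Everything below (the `pre`/`post` method names, the `(..., n, C)` stream layout, and the identity-preserving initialization) is illustrative, not the PR's actual code:

```python
import jax.numpy as jnp
from flax import nnx

class Connector(nnx.Module):
    """Illustrative hyper-connection wrapper around a residual block.

    Hidden states carry an extra stream dimension of size
    `expansion_rate` (n), i.e. shape (..., n, hidden_dim).
    """

    def __init__(self, hidden_dim: int, expansion_rate: int):
        n = expansion_rate
        # Depth connection: which mix of streams the wrapped block sees.
        # Reading only stream 0 makes n == 1 collapse to a plain residual.
        self.input_weights = nnx.Param(jnp.zeros((n,)).at[0].set(1.0))
        # Width connections: identity stream mixing, plus unit weights for
        # writing the block output back into every stream.
        self.stream_mix = nnx.Param(jnp.eye(n))
        self.output_weights = nnx.Param(jnp.ones((n,)))

    def pre(self, h):
        # (..., n, C) -> (..., C): the input handed to attention/MLP.
        return jnp.einsum("...nc,n->...c", h, self.input_weights.value)

    def post(self, h, y):
        # Mix streams, then add the block output y back into each stream.
        mixed = jnp.einsum("...nc,nm->...mc", h, self.stream_mix.value)
        return mixed + self.output_weights.value[:, None] * y[..., None, :]
```

With `expansion_rate == 1` and these initial values, `pre` returns the single stream and `post` computes `h + y`, i.e. a standard residual connection.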
skyrl-tx/tx/layers/layernorm.py
Outdated
```diff
 self.eps = eps
 self.weight = Param(
-    size, dtype=dtype, kernel_init=nnx.with_partitioning(nnx.initializers.normal(), jax.P(None)), rngs=rngs
+    size, dtype=dtype, kernel_init=nnx.with_partitioning(nnx.initializers.ones_init(), jax.P(None)), rngs=rngs
```
Temporary, testing
https://docs.pytorch.org/docs/stable/generated/torch.nn.modules.normalization.RMSNorm.html
Torch also initializes to one by default.
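For reference, a minimal standalone RMSNorm sketch with the ones initialization under discussion (illustrative, not the file's actual code):

```python
import jax.numpy as jnp
from flax import nnx

class RMSNorm(nnx.Module):
    def __init__(self, size: int, eps: float = 1e-6):
        self.eps = eps
        # Start at ones so the scale is a no-op at init,
        # matching the torch.nn.RMSNorm default linked above.
        self.weight = nnx.Param(jnp.ones((size,)))

    def __call__(self, x):
        # Normalize by the root mean square over the feature axis.
        rms = jnp.sqrt(jnp.mean(x * x, axis=-1, keepdims=True) + self.eps)
        return (x / rms) * self.weight.value
```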
This looks very elegant, thanks a lot for putting it together! Have you tried any end-to-end runs yet / studied the performance, both in terms of learning dynamics / accuracy and in terms of how much slowdown it incurs? :)
Just waiting for the weekend to give it a spin 😅 I'll give Qwen 0.6B a shot on an A100/H100.
Sounds great! I'm putting together the 0.3.0 release at the moment, so this will probably need to wait until after that, but 0.3.1 should come relatively soon thereafter, so it is not a problem. I'll put a callout in the release blog anyway; if somebody wants to try it out, they can just apply the diff themselves, given how simple it is :)
Addresses #952
This PR is a general implementation of Hyper Connections.
It is intended to be an extension like LoRA, where the default case mimics a standard residual connection with identity mappings.
Default case: `trainable` is false and the expansion rate is 1.
For expansion rate > 1, the connection matrices are initialized so that they preserve the identity mapping, so expansion rate > 1 with `trainable` false still results in the same outputs (a small numeric check is sketched below).
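A small numeric check of this identity-preserving behavior. This is a sketch under assumptions not stated in the PR: streams are initialized as copies of the input, the input weights read a single stream, stream mixing is the identity, and streams are aggregated by mean at the end (as in the review suggestion above):

```python
import jax.numpy as jnp

def block(x):
    # Stand-in for an attention/MLP block.
    return 2.0 * x + 1.0

x = jnp.arange(8.0)

# n = 1: a standard residual connection.
out_residual = x + block(x)

# n = 4 with untrained, identity-preserving connections.
n = 4
h = jnp.broadcast_to(x, (n, 8))        # streams start as copies of x
pre = h[0]                             # input weights read one stream
h = h + jnp.ones((n, 1)) * block(pre)  # stream mix = identity, output weights = 1
out_hyper = h.mean(axis=0)             # aggregate streams at the end

assert jnp.allclose(out_residual, out_hyper)
```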
Todos
Future work